{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Transformer for CIFAR10\n",
    "\n",
    "A configurable transformer model will be used for CIFAR10 image classification. \n",
    "\n",
    "The vision transformer model is a modified version of ViT. The changes are:\n",
    "1) No position embedding.\n",
    "2) No dropout is used.\n",
    "3) All encoder features are used for class prediction.\n",
    "\n",
    "The code below is a simplified version of the corresponding `timm` modules.\n",
    "\n",
    "Let us import the required packages."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torchmetrics\n",
    "from argparse import ArgumentParser\n",
    "from pytorch_lightning import LightningModule, Trainer, LightningDataModule\n",
    "from torch.optim import Adam\n",
    "from torch.optim.lr_scheduler import CosineAnnealingLR\n",
    "\n",
    "from einops import rearrange\n",
    "from torch import nn\n",
    "from torchvision.datasets.cifar import CIFAR10"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Attention Module\n",
    "\n",
    "The `Attention` module is the core of the vision transformer model. It implements the attention mechanism:\n",
    "\n",
    "1) Project the input into Q, K and V with a single linear layer (`qkv`) and split them across the heads.\n",
    "2) Compute the dot product of Q and K.\n",
    "3) Scale the result of 2) by 1/sqrt(`head_dim`).\n",
    "4) Apply softmax to the result of 3).\n",
    "5) Multiply the result of 4) by V, concatenate the heads, and apply the output projection (`proj`)."
   ]
  },
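  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a compact summary of the steps above, the computation of one head can be sketched in standard notation. The symbols are introduced here for illustration and are not names from the code: $Q_h$, $K_h$ and $V_h$ are the query, key and value matrices of head $h$, $d_{\\text{head}}$ is `head_dim`, $H$ is `num_heads`, and $W_o$, $b_o$ are the weight and bias of `proj`.\n",
    "\n",
    "$$O_h = \\mathrm{softmax}\\left(\\frac{Q_h K_h^\\top}{\\sqrt{d_{\\text{head}}}}\\right) V_h$$\n",
    "\n",
    "$$\\mathrm{Attention}(x) = \\mathrm{Concat}(O_1, \\ldots, O_H)\\, W_o + b_o$$\n",
    "\n",
    "The softmax is taken over the key dimension, so each row of the attention matrix sums to 1."
   ]
  },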
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Attention(nn.Module):\n",
    "    def __init__(self, dim, num_heads=8, qkv_bias=False):\n",
    "        super().__init__()\n",
    "        assert dim % num_heads == 0, 'dim should be divisible by num_heads'\n",
    "        self.num_heads = num_heads\n",
    "        head_dim = dim // num_heads\n",
    "        self.scale = head_dim ** -0.5\n",
    "\n",
    "        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n",
    "        self.proj = nn.Linear(dim, dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        B, N, C = x.shape\n",
    "        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n",
    "        q, k, v = qkv.unbind(0)   # make torchscript happy (cannot use tensor as tuple)\n",
    "\n",
    "        attn = (q @ k.transpose(-2, -1)) * self.scale\n",
    "        attn = attn.softmax(dim=-1)\n",
    "\n",
    "        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n",
    "        x = self.proj(x)\n",
    "\n",
    "        return x"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### MLP Module\n",
    "\n",
    "The MLP module is made of two linear layers. A non-linear activation (GELU by default) is applied to the output of the first layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Mlp(nn.Module):\n",
    "    \"\"\" MLP as used in Vision Transformer, MLP-Mixer and related networks\n",
    "    \"\"\"\n",
    "    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU):\n",
    "        super().__init__()\n",
    "        out_features = out_features or in_features\n",
    "        hidden_features = hidden_features or in_features\n",
    "      \n",
    "        self.fc1 = nn.Linear(in_features, hidden_features)\n",
    "        self.act = act_layer()\n",
    "        self.fc2 = nn.Linear(hidden_features, out_features)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.fc1(x)\n",
    "        x = self.act(x)\n",
    "        x = self.fc2(x)\n",
    "        return x"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The Block Module\n",
    "\n",
    "The `Block` module represents one encoder transformer block. It consists of two sub-modules:\n",
    "1) The Attention module\n",
    "2) The MLP module\n",
    "\n",
    "Layer norm is applied to the input of each sub-module (pre-norm), and the output of each sub-module is added back to its input through a residual connection."
   ]
  },
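  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The data flow through one block can be written as two residual updates. This is a sketch in notation introduced here for illustration: $\\mathrm{LN}_1$ and $\\mathrm{LN}_2$ stand for `norm1` and `norm2`, $x$ is the block input, $z$ is the intermediate result and $y$ is the block output.\n",
    "\n",
    "$$z = x + \\mathrm{Attention}(\\mathrm{LN}_1(x))$$\n",
    "\n",
    "$$y = z + \\mathrm{MLP}(\\mathrm{LN}_2(z))$$\n",
    "\n",
    "Both updates keep the shape of $x$, so blocks can be stacked without any reshaping in between."
   ]
  },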
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Block(nn.Module):\n",
    "\n",
    "    def __init__(\n",
    "            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, \n",
    "            act_layer=nn.GELU, norm_layer=nn.LayerNorm):\n",
    "        super().__init__()\n",
    "        self.norm1 = norm_layer(dim)\n",
    "        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias) \n",
    "        self.norm2 = norm_layer(dim)\n",
    "        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer) \n",
    "   \n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x + self.attn(self.norm1(x))\n",
    "        x = x + self.mlp(self.norm2(x))\n",
    "        return x"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The Transformer Module\n",
    "\n",
    "The feature encoder is made of several transformer blocks. The most important attributes are:\n",
    "1) `num_blocks`: the number of encoder blocks (set from `depth` in `LitTransformer` below)\n",
    "2) `num_heads`: the number of attention heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Transformer(nn.Module):\n",
    "    def __init__(self, dim, num_heads, num_blocks, mlp_ratio=4., qkv_bias=False,  \n",
    "                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):\n",
    "        super().__init__()\n",
    "        self.blocks = nn.ModuleList([Block(dim, num_heads, mlp_ratio, qkv_bias, \n",
    "                                     act_layer, norm_layer) for _ in range(num_blocks)])\n",
    "\n",
    "    def forward(self, x):\n",
    "        for block in self.blocks:\n",
    "            x = block(x)\n",
    "        return x"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### The optional parameter initialization as adopted from `timm`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_weights_vit_timm(module: nn.Module):\n",
    "    \"\"\" ViT weight initialization, original timm impl (for reproducibility) \"\"\"\n",
    "    if isinstance(module, nn.Linear):\n",
    "        nn.init.trunc_normal_(module.weight, mean=0.0, std=0.02)\n",
    "        if module.bias is not None:\n",
    "            nn.init.zeros_(module.bias)\n",
    "    elif hasattr(module, 'init_weights'):\n",
    "        module.init_weights()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### PyTorch Lightning for CIFAR10 Image Classification\n",
    "\n",
    "We use the `Transformer` module to build the feature encoder. Before the `Transformer` can be used, we convert the input image into patches. Each patch is then linearly projected to the embedding dimension. The output is then passed to the `Transformer`.\n",
    "\n",
    "Another difference from the original ViT is that all output features are used for the final classification. In ViT, only the first feature (the class token) is used.\n",
    "\n"
   ]
  },
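  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the shapes concrete, here is a worked example for a $3 \\times 32 \\times 32$ CIFAR10 image split into a $p \\times p$ grid of patches, where $p$ is `patch_num`. The symbols $s$ (patch side length), $N$ (`seqlen`) and $D$ (`patch_dim`) are shorthand introduced here for illustration.\n",
    "\n",
    "$$s = \\frac{32}{p}, \\qquad N = p^2, \\qquad D = 3 s^2$$\n",
    "\n",
    "With `patch_num=8`, this gives $s = 4$, $N = 64$ and $D = 48$. With `patch_num=4`, it gives $s = 8$, $N = 16$ and $D = 192$. Since all $N$ output features are flattened before the classifier, the final linear layer receives $N$ times `embed_dim` input features."
   ]
  },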
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class LitTransformer(LightningModule):\n",
    "    def __init__(self, num_classes=10, lr=0.001, max_epochs=30, depth=12, embed_dim=64,\n",
    "                 head=4, patch_dim=192, seqlen=16, **kwargs):\n",
    "        super().__init__()\n",
    "        self.save_hyperparameters()\n",
    "        self.encoder = Transformer(dim=embed_dim, num_heads=head, num_blocks=depth, mlp_ratio=4.,\n",
    "                                   qkv_bias=False, act_layer=nn.GELU, norm_layer=nn.LayerNorm)\n",
    "        self.embed = torch.nn.Linear(patch_dim, embed_dim)\n",
    "\n",
    "        self.fc = nn.Linear(seqlen * embed_dim, num_classes)\n",
    "        self.loss = torch.nn.CrossEntropyLoss()\n",
    "        \n",
    "        self.reset_parameters()\n",
    "        self.accuracy = torchmetrics.Accuracy(task=\"multiclass\", num_classes=num_classes)\n",
    "\n",
    "\n",
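    "    def reset_parameters(self):\n",
    "        # NOTE: this call only visits the top-level module, which is not an nn.Linear,\n",
    "        # so the PyTorch default initialization is kept. Use\n",
    "        # self.apply(init_weights_vit_timm) to initialize every nn.Linear instead.\n",
    "        init_weights_vit_timm(self)\n",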
    "    \n",
    "\n",
    "    def forward(self, x):\n",
    "        # Linear projection\n",
    "        x = self.embed(x)\n",
    "            \n",
    "        # Encoder\n",
    "        x = self.encoder(x)\n",
    "        x = x.flatten(start_dim=1)\n",
    "\n",
    "        # Classification head\n",
    "        x = self.fc(x)\n",
    "        return x\n",
    "    \n",
    "    def configure_optimizers(self):\n",
    "        optimizer = Adam(self.parameters(), lr=self.hparams.lr)\n",
    "        # this decays the learning rate to 0 after max_epochs using cosine annealing\n",
    "        scheduler = CosineAnnealingLR(optimizer, T_max=self.hparams.max_epochs)\n",
    "        return [optimizer], [scheduler]\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        x, y = batch\n",
    "        y_hat = self(x)\n",
    "        loss = self.loss(y_hat, y)\n",
    "        return loss\n",
    "    \n",
    "    def _eval_step(self, batch):\n",
    "        # shared by validation and test: compute the loss and accumulate accuracy\n",
    "        x, y = batch\n",
    "        y_hat = self(x)\n",
    "        loss = self.loss(y_hat, y)\n",
    "        self.accuracy.update(y_hat, y)\n",
    "        return loss\n",
    "\n",
    "    def test_step(self, batch, batch_idx):\n",
    "        loss = self._eval_step(batch)\n",
    "        self.log(\"test_loss\", loss, on_epoch=True, prog_bar=True)\n",
    "        return loss\n",
    "\n",
    "    def on_test_epoch_end(self):\n",
    "        # replaces test_epoch_end(outputs), which was removed in Lightning 2.0\n",
    "        self.log(\"test_acc\", self.accuracy.compute()*100., prog_bar=True)\n",
    "        self.accuracy.reset()\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        return self._eval_step(batch)\n",
    "    \n",
    "    def on_validation_epoch_end(self):\n",
    "        self.log(\"val_acc_epoch\", self.accuracy.compute()*100., on_epoch=True, prog_bar=True)\n",
    "        # reset so that the next epoch does not accumulate this epoch's statistics\n",
    "        self.accuracy.reset()\n",
    "\n",
    "\n",
    "# a lightning data module for cifar 10 dataset\n",
    "class LitCifar10(LightningDataModule):\n",
    "    def __init__(self, batch_size=32, num_workers=32, patch_num=4, **kwargs):\n",
    "        super().__init__()\n",
    "        self.batch_size = batch_size\n",
    "        self.patch_num = patch_num\n",
    "        self.num_workers = num_workers\n",
    "\n",
    "    def prepare_data(self):\n",
    "        self.train_set = CIFAR10(root='~/data', train=True,\n",
    "                                 download=True, transform=torchvision.transforms.ToTensor())\n",
    "        self.test_set = CIFAR10(root='~/data', train=False,\n",
    "                                download=True, transform=torchvision.transforms.ToTensor())\n",
    "\n",
    "    def collate_fn(self, batch):\n",
    "        x, y = zip(*batch)\n",
    "        x = torch.stack(x, dim=0)\n",
    "        y = torch.LongTensor(y)\n",
    "        x = rearrange(x, 'b c (p1 h) (p2 w) -> b (p1 p2) (c h w)', p1=self.patch_num, p2=self.patch_num)\n",
    "        return x, y\n",
    "\n",
    "    def train_dataloader(self):\n",
    "        return torch.utils.data.DataLoader(self.train_set, batch_size=self.batch_size, \n",
    "                                        shuffle=True, collate_fn=self.collate_fn,\n",
    "                                        num_workers=self.num_workers)\n",
    "\n",
    "    def test_dataloader(self):\n",
    "        return torch.utils.data.DataLoader(self.test_set, batch_size=self.batch_size, \n",
    "                                        shuffle=False, collate_fn=self.collate_fn,\n",
    "                                        num_workers=self.num_workers)\n",
    "\n",
    "    def val_dataloader(self):\n",
    "        return self.test_dataloader()\n",
    "\n",
    "\n",
    "\n",
    "def get_args():\n",
    "    parser = ArgumentParser(description='PyTorch Transformer')\n",
    "    parser.add_argument('--depth', type=int, default=12, help='depth')\n",
    "    parser.add_argument('--embed_dim', type=int, default=64, help='embedding dimension')\n",
    "    parser.add_argument('--num_heads', type=int, default=4, help='num_heads')\n",
    "\n",
    "    parser.add_argument('--patch_num', type=int, default=8, help='patch_num')\n",
    "    parser.add_argument('--kernel_size', type=int, default=3, help='kernel size')\n",
    "    parser.add_argument('--batch_size', type=int, default=64, metavar='N',\n",
    "                        help='input batch size for training (default: 64)')\n",
    "    parser.add_argument('--max-epochs', type=int, default=10, metavar='N',\n",
    "                        help='number of epochs to train (default: 10)')\n",
    "    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',\n",
    "                        help='learning rate (default: 0.001)')\n",
    "\n",
    "    parser.add_argument('--accelerator', default='gpu', type=str, metavar='N')\n",
    "    parser.add_argument('--devices', default=1, type=int, metavar='N')\n",
    "    parser.add_argument('--dataset', default='cifar10', type=str, metavar='N')\n",
    "    parser.add_argument('--num_workers', default=4, type=int, metavar='N')\n",
    "    args = parser.parse_args(\"\")\n",
    "    return args\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Performance on different settings\n",
    "\n",
    "The following table shows the accuracy obtained with different settings. Generally, the Transformer is better than an MLP in terms of accuracy and parameter count. However, its accuracy is lower than that of CNN models.\n",
    "\n",
    "The most important thing to note, however, is that Transformers are more general-purpose models than CNNs. They can process different types of data, including multiple types at the same time. This is why they form the backbone of many high-performing models like BERT, GPT-3, PaLM and Gato.\n",
    "\n",
    "| **Depth** | **Head** | **Embed dim** | **Patch size** | **Seq len** | **Params** | **Accuracy** | \n",
    "| -: | -: | -: | -: | -: | -: | -: |\n",
    "| 12 | 4 | 32 | 4x4 | 64 | 173k | 68.2% |\n",
    "| 12 | 4 | 64 | 4x4 | 64 | 641k | 71.1% |\n",
    "| 12 | 4 | 128 | 4x4 | 64 | 2.5M | 71.5% |"
   ]
  },
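  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The parameter counts in the table can be checked by hand. The following is a sketch of the arithmetic, assuming the settings used in the code above (`qkv_bias=False`, `mlp_ratio=4`, 10 classes, `patch_num=8`). The symbols are introduced here for illustration: $d$ is `embed_dim`, $L$ is the depth, $N$ is the sequence length and $D$ is the patch dimension.\n",
    "\n",
    "$$P_{\\text{block}} = \\underbrace{3d^2}_{\\text{qkv}} + \\underbrace{d^2 + d}_{\\text{proj}} + \\underbrace{8d^2 + 5d}_{\\text{MLP}} + \\underbrace{4d}_{\\text{2 LayerNorms}} = 12d^2 + 10d$$\n",
    "\n",
    "$$P_{\\text{total}} = L \\cdot P_{\\text{block}} + \\underbrace{(D + 1)\\, d}_{\\text{embed}} + \\underbrace{10\\, (N d + 1)}_{\\text{classifier}}$$\n",
    "\n",
    "For $d = 64$, $L = 12$, $N = 64$ and $D = 48$, this gives $P_{\\text{block}} = 49{,}792$ and $P_{\\text{total}} = 597{,}504 + 3{,}136 + 40{,}970 = 641{,}610$, which agrees with the 641k entry in the table."
   ]
  },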
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using 16bit Automatic Mixed Precision (AMP)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([64, 48])\n",
      "Embed dim: 64\n",
      "Patch size: 4\n",
      "Sequence length: 64\n",
      "Number of model parameters: 0.64 million\n",
      "Running on: gpu\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]\n",
      "\n",
      "  | Name     | Type               | Params\n",
      "------------------------------------------------\n",
      "0 | encoder  | Transformer        | 597 K \n",
      "1 | embed    | Linear             | 3.1 K \n",
      "2 | fc       | Linear             | 41.0 K\n",
      "3 | loss     | CrossEntropyLoss   | 0     \n",
      "4 | accuracy | MulticlassAccuracy | 0     \n",
      "------------------------------------------------\n",
      "641 K     Trainable params\n",
      "0         Non-trainable params\n",
      "641 K     Total params\n",
      "2.566     Total estimated model params size (MB)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e5a66dc192cc4e5f95f4f61b913004d1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Sanity Checking: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "aebaa235c644495ca83e422eeb8465b3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e99c3040905440b6a97347f49a6f80af",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d0e855372393454fa948a2158876ccf9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f8c14438f0e74f32ac0d29751866f08d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3c0c772543544df0bab5fc2f71860208",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "169b8b816ec0477e89334c26148415ad",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "65e907ba8e244f6c932c4c5d812666a9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "81449ea0bb0c419484a4c6961db3ce0f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "89b33a839c564f4ea34f8d8921965efb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1778c3fb1a1a4c488922fc8bb1a76e25",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9956b7c2a917492cbd0a0c95eb711ad7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=10` reached.\n"
     ]
    }
   ],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    args = get_args()\n",
    "\n",
    "    datamodule = LitCifar10(batch_size=args.batch_size,\n",
    "                            patch_num=args.patch_num, \n",
    "                            num_workers=args.num_workers * args.devices)\n",
    "    datamodule.prepare_data()\n",
    "\n",
    "    sample_data = next(iter(datamodule.train_dataloader()))\n",
    "    data = sample_data[0][0]\n",
    "    print(data.shape)\n",
    "    \n",
    "\n",
    "    #data = iter(datamodule.train_dataloader()).next()\n",
    "    patch_dim = data.shape[-1]\n",
    "    seqlen = data.shape[-2]\n",
    "    print(\"Embed dim:\", args.embed_dim)\n",
    "    print(\"Patch size:\", 32 // args.patch_num)\n",
    "    print(\"Sequence length:\", seqlen)\n",
    "\n",
    "\n",
    "    model = LitTransformer(num_classes=10, lr=args.lr, max_epochs=args.max_epochs, \n",
    "                           depth=args.depth, embed_dim=args.embed_dim, head=args.num_heads,\n",
    "                           patch_dim=patch_dim, seqlen=seqlen,)\n",
    "    \n",
    "    print(f\"Number of model parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f} million\")\n",
    "    print(f\"Running on: {args.accelerator}\")\n",
    "\n",
    "    trainer = Trainer(accelerator=args.accelerator, devices=args.devices,\n",
    "                      max_epochs=args.max_epochs, precision=16 if args.accelerator == 'gpu' else 32,)\n",
    "    trainer.fit(model, datamodule=datamodule)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "52772734322c44a04e342c358be70f2ff4d97da358cb3cd38ceb0f6066598be5"
  },
  "kernelspec": {
   "display_name": "Python 3.7.3 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
